import torch

import d2l


def corr2d_multi_in(X, K):
    """Sum per-channel 2-D correlations over all input channels.

    Each input channel ``X[i]`` is paired with its kernel ``K[i]``, the pair
    is passed to ``d2l.corr2d``, and the per-channel results are added
    together, collapsing the channel dimension into a single output.

    Args:
        X: Input channels, iterated along the first axis (presumably a tensor
            of shape (c_i, h, w) -- confirm with callers).
        K: Kernels, one per input channel, iterated along the first axis.

    Returns:
        The element-wise sum of ``d2l.corr2d(X[i], K[i])`` over all channels.
        With zero channels this is the integer ``0`` (the start value of
        ``sum``).

    Raises:
        ValueError: If ``X`` and ``K`` do not have the same number of
            channels.
    """
    # Materialize both inputs so their lengths can be compared; this also
    # keeps accepting one-shot iterables, as the original zip() did.
    xs, ks = list(X), list(K)
    # zip() would silently stop at the shorter input, dropping channels and
    # returning a wrong result instead of reporting the shape mismatch.
    if len(xs) != len(ks):
        raise ValueError(
            f"channel mismatch: X has {len(xs)} channels, K has {len(ks)}"
        )
    return sum(d2l.corr2d(x, k) for x, k in zip(xs, ks))


# Example input: two 3x3 channels. Channel c holds the values 0..8 laid out
# row-major, shifted up by c:
#   [[0, 1, 2], [3, 4, 5], [6, 7, 8]] and [[1, 2, 3], [4, 5, 6], [7, 8, 9]].
X = torch.stack([torch.arange(9.0).reshape(3, 3) + c for c in range(2)])

# One 2x2 kernel per input channel. Kernel c holds 0..3 shifted up by c:
#   [[0, 1], [2, 3]] and [[1, 2], [3, 4]].
K = torch.stack([torch.arange(4.0).reshape(2, 2) + c for c in range(2)])

# NOTE(review): if d2l.corr2d is plain 2-D cross-correlation, this should
# print tensor([[ 56.,  72.], [104., 120.]]) -- confirm against d2l.
print(corr2d_multi_in(X, K))
